import argparse
import time

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import resnet
import json
import argparse
import sys
import os
# Add the parent directory to sys.path so the `optimizers` package imported below can be found.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from optimizers.lamb import create_lamb_optimizer
from optimizers.ALTO import create_ALTO_optimizer
from adabelief_pytorch import AdaBelief
# Alphabetically sorted names of every lowercase, callable attribute of the
# `resnet` module whose name starts with "resnet"; these are the values
# accepted by --arch.
model_names = sorted(
    attr_name
    for attr_name, attr_value in vars(resnet).items()
    if attr_name.startswith("resnet")
    and attr_name.islower()
    and callable(attr_value)
)

# Command-line interface. The parsed namespace is stored in the module-level
# `args`, which main() and its helpers read directly.
parser = argparse.ArgumentParser(description='Proper ResNets for CIFAR100 in pytorch')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet32',
                    choices=model_names,
                    help='model architecture: ' + ' | '.join(model_names) +
                    ' (default: resnet32)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=200, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
# The help text previously advertised a default of 128 while the actual
# default is 16384; the two now agree.
parser.add_argument('-b', '--batch-size', default=16384, type=int,
                    metavar='N', help='mini-batch size (default: 16384)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--optimizer', default='sgd',
                    choices=['sgd', 'adam', 'adamW', 'lamb', 'ALTO', 'adaBelief'],
                    help='optimizer to train with (default: sgd)')
parser.add_argument('--log', default='log.json',
                    help='path of the JSON file that receives per-epoch metrics '
                         '(default: log.json)')
# NOTE: parsing happens at import time, so importing this module from another
# program consumes that program's sys.argv.
args = parser.parse_args()

def _build_data_loaders():
    """Create the CIFAR-100 training and validation data loaders.

    Both loaders use ``args.batch_size`` and ``args.workers``. The training
    set is downloaded into ``../datasets`` if missing and is augmented with a
    random horizontal flip and a padded random crop; the validation set is only
    converted to a tensor and normalized.

    Returns:
        A ``(train_loader, val_loader)`` tuple.
    """
    # NOTE(review): these look like the commonly used ImageNet channel
    # statistics rather than CIFAR-100-specific values -- confirm intent.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, 4),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR100(root='../datasets', train=True,
                          transform=train_transform, download=True),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    # No download=True here: this relies on the training dataset above having
    # already fetched the archive into the same root.
    val_loader = torch.utils.data.DataLoader(
        datasets.CIFAR100(root='../datasets', train=False,
                          transform=val_transform),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    return train_loader, val_loader


def _build_optimizer(model):
    """Instantiate the optimizer selected by ``args.optimizer`` for ``model``.

    Raises:
        ValueError: if ``args.optimizer`` is not one of the supported names.
    """
    if args.optimizer == 'sgd':
        return torch.optim.SGD(model.parameters(),
                               args.lr,
                               momentum=args.momentum,
                               weight_decay=args.weight_decay)
    if args.optimizer == 'adam':
        return torch.optim.Adam(model.parameters(),
                                lr=args.lr,
                                weight_decay=args.weight_decay)
    if args.optimizer == 'adamW':
        return torch.optim.AdamW(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    if args.optimizer == 'adaBelief':
        # NOTE(review): unlike every other branch, --weight-decay is not
        # forwarded here -- confirm this is intentional.
        return AdaBelief(model.parameters(),
                         lr=args.lr,
                         betas=(0.9, 0.999))
    if args.optimizer == 'ALTO':
        return create_ALTO_optimizer(model,
                                     lr=args.lr,
                                     betas=(0.99, 0.9, 0.99),
                                     weight_decay=args.weight_decay)
    if args.optimizer == 'lamb':
        return create_lamb_optimizer(model,
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)
    raise ValueError('Unknown optimizer: {}'.format(args.optimizer))


def main():
    """Train the selected ResNet on CIFAR-100 and log per-epoch metrics.

    Builds the model (wrapped in ``DataParallel`` and moved to the GPU), the
    data loaders, the loss and the optimizer chosen on the command line, then
    runs epochs ``args.start_epoch`` .. ``args.epochs - 1``. After each epoch
    the metrics are printed and the whole history is rewritten to ``args.log``
    as JSON.
    """
    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs!")
    # NOTE(review): the model is built with its default constructor arguments;
    # verify that its output size matches the 100 classes of CIFAR-100.
    model = torch.nn.DataParallel(resnet.__dict__[args.arch]())
    model.cuda()

    cudnn.benchmark = True

    train_loader, val_loader = _build_data_loaders()

    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = _build_optimizer(model)

    # Learning-rate schedule: multiply the LR by 0.1 at 25%, 50% and 75% of the
    # run. Integer division keeps the milestones integral; with true division
    # a fractional milestone (e.g. 2.5 when --epochs=10) never matches an epoch
    # index and that decay step is silently skipped.
    milestones = [args.epochs // 4, args.epochs // 2, 3 * args.epochs // 4]
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                        milestones=milestones,
                                                        gamma=0.1)
    # NOTE(review): the scheduler always starts from epoch 0, even when
    # --start-epoch is non-zero; no checkpoint is restored in this function.
    json_data = []
    for epoch in range(args.start_epoch, args.epochs):
        train_loss, train_prec1 = train(train_loader, model, criterion, optimizer)
        val_loss, val_prec1 = validate(val_loader, model, criterion)
        lr_scheduler.step()
        print('Epoch: [{0}]\t'
              'Train Loss {1:.4f}\t'
              'Train Prec@1 {2:.3f}\t'
              'Val Loss {3:.4f}\t'
              'Val Prec@1 {4:.3f}'.format(
                  epoch, train_loss, train_prec1,
                  val_loss, val_prec1))
        # The JSON log uses 1-based epoch numbers, whereas the console line
        # above prints the 0-based loop index.
        epoch_data = {
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "train_accuracy": train_prec1,
            "val_loss": val_loss,
            "val_accuracy": val_prec1
        }
        json_data.append(epoch_data)
        # Rewrite the complete history every epoch so an interrupted run still
        # leaves a valid JSON file behind.
        with open(args.log, "w") as json_file:
            json.dump(json_data, json_file, indent=4)

def train(train_loader, model, criterion, optimizer):
    """Run one training epoch over ``train_loader``.

    Args:
        train_loader: iterable yielding ``(images, target)`` batches.
        model: network to optimize; switched to training mode here.
        criterion: loss callable invoked as ``criterion(output, target)``.
        optimizer: optimizer stepped once per batch.

    Returns:
        ``(avg_loss, avg_prec1)``: loss and top-1 accuracy (in percent), each
        averaged over all samples seen in the epoch.

    Raises:
        ZeroDivisionError: if the loader yields no samples.
    """
    model.train()
    total_loss = 0.0
    total_prec1 = 0.0
    total_samples = 0

    for images, target in train_loader:
        target = target.cuda()
        images = images.cuda()
        output = model(images)
        loss = criterion(output, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_size = images.size(0)
        # accuracy() returns a plain float; detach() keeps the metric
        # computation out of the autograd graph.
        prec1 = accuracy(output.detach(), target)
        # Weight per-batch values by batch size so the epoch figures are
        # per-sample averages (assumes the criterion returns a batch mean).
        total_loss += loss.item() * batch_size
        total_prec1 += prec1 * batch_size
        total_samples += batch_size

    avg_loss = total_loss / total_samples
    avg_prec1 = total_prec1 / total_samples
    return avg_loss, avg_prec1

def validate(val_loader, model, criterion):
    """Evaluate ``model`` on ``val_loader`` without computing gradients.

    Args:
        val_loader: iterable yielding ``(images, target)`` batches.
        model: network to evaluate; switched to evaluation mode here.
        criterion: loss callable invoked as ``criterion(output, target)``.

    Returns:
        ``(avg_loss, avg_prec1)``: loss and top-1 accuracy (in percent), each
        averaged over all samples in the loader.

    Raises:
        ZeroDivisionError: if the loader yields no samples.
    """
    model.eval()
    total_loss = 0.0
    total_prec1 = 0.0
    total_samples = 0

    with torch.no_grad():
        for images, target in val_loader:
            # Move each tensor to the GPU exactly once.
            target = target.cuda()
            images = images.cuda()
            output = model(images)
            loss = criterion(output, target)

            batch_size = images.size(0)
            # accuracy() returns a plain float; no detach is needed because
            # gradients are disabled in this context.
            prec1 = accuracy(output, target)
            # Weight per-batch values by batch size so the final figures are
            # per-sample averages (assumes the criterion returns a batch mean).
            total_loss += loss.item() * batch_size
            total_prec1 += prec1 * batch_size
            total_samples += batch_size

    avg_loss = total_loss / total_samples
    avg_prec1 = total_prec1 / total_samples
    return avg_loss, avg_prec1

def accuracy(output, target, topk=(1,)):
    """Compute precision@k, in percent, for k = max(topk).

    Args:
        output: 2-D tensor of scores; the top-k entries are taken along
            dimension 1.
        target: tensor of class indices with one entry per row of ``output``.
        topk: iterable of k values. Only the largest one is evaluated, so the
            default ``(1,)`` yields top-1 accuracy.

    Returns:
        A Python float: the percentage of rows whose target index appears
        among the ``max(topk)`` highest-scoring entries. Returns 0.0 for an
        empty batch.
    """
    maxk = max(topk)
    batch_size = target.size(0)
    if batch_size == 0:
        # Nothing to score; avoid dividing by zero below.
        return 0.0

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    # reshape() rather than view(): `correct` derives from a transposed tensor
    # and may be non-contiguous, in which case view(-1) raises for maxk > 1.
    correct_k = correct[:maxk].reshape(-1).float().sum(0, keepdim=True)
    return correct_k.mul_(100.0 / batch_size).item()  # plain float, not a tensor

# Script entry point: start training only when this file is executed directly.
if __name__ == '__main__':
    main()
